import pandas as pd
import numpy as np
from transformers import RobertaTokenizer, RobertaModel
import pickle
from pathlib import Path
from tqdm.auto import tqdm
import re
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
import json
from multiprocessing import Pool

# Root annotations: a mapping from a word to its annotation string, in which
# root segments are delimited as §{...}§ (that format is parsed by
# najdi_korene). It is read with pickle, so the file must come from a trusted
# source.
with open('./SKMT_lib_v2/word_root_20231210_sorted', mode='rb') as file:
    word_root = pickle.load(file)
# Every word that has a root annotation, kept as a set for fast membership tests.
korenoveSlova = set(word_root.keys())

# Directory holding the tokenizer artefacts. It stays a plain str ending in a
# slash because sub-paths are built from it by string concatenation.
files_file_path = "./tokenizers/"

# Substring -> replacement table that dekoduj applies to every token.
with open(f"{files_file_path}kodovanie.json", mode="r", encoding="utf-8") as f:
    dictionary = json.load(f)

def dekoduj(tokens):
    """Return a new list with every token passed through the ``dictionary`` table.

    For each token, the (key, value) pairs of the module-level ``dictionary``
    are applied one after another with ``str.replace``, in the dict's
    iteration order. Later pairs therefore see the output of earlier ones.
    """
    def _apply_table(text):
        for encoded, readable in dictionary.items():
            if encoded in text:
                text = text.replace(encoded, readable)
        return text

    return [_apply_table(token) for token in tokens]

# The three tokenizers being compared:
#   tokenizer_SB      - RobertaTokenizer for the 'gerulata/slovakbert' model id,
#   tokenizer_pureBPE - RobertaTokenizer loaded from ./tokenizers/pureBPE,
#   tokenizer_skmt    - SKMorfoTokenizer from the local SKMT_lib_v2 package.
tokenizer_SB = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
tokenizer_pureBPE = RobertaTokenizer.from_pretrained(files_file_path+'pureBPE')
# NOTE(review): star import. Only SKMorfoTokenizer is visibly used from it, but
# it can silently rebind any name defined above (e.g. dekoduj, dictionary, re).
# Consider importing SKMorfoTokenizer explicitly once the module's contents
# have been checked.
from SKMT_lib_v2.SKMT_BPE import *
tokenizer_skmt = SKMorfoTokenizer()

def korene_in_slovo(korene, tokens):
    """Report whether the regular expression ``korene`` matches inside ``tokens``.

    ``korene`` is regex source text and ``tokens`` is the string searched.
    The match may start anywhere (``re.search`` semantics, not a full match).
    An empty pattern always matches.
    """
    return re.search(korene, tokens) is not None

def najdi_korene(tokens, slovo):
    """Check whether the root segments of an annotated word survive tokenization.

    Parameters
    ----------
    tokens : iterable of str
        Tokens produced for the word.
    slovo : str
        Root annotation of the word, in which each root segment is written
        as ``§{root}§``. Text outside those markers is ignored.

    Returns
    -------
    bool
        True when every root segment occurs, in annotation order, wholly
        inside a single token. Successive roots may lie in different (later)
        tokens. Also True when the annotation contains no root segment,
        because the resulting empty pattern matches anything.
    """
    roots = re.findall(r'§\{([^}]+)\}§', slovo)
    # Roots are literal text. Escape them so characters such as '.', '?', '+'
    # or '(' are matched literally instead of being read as regex syntax.
    pattern = ".*".join(re.escape(root) for root in roots)
    # Tokens are joined with ',' so a root cannot match across a token
    # boundary; the '.*' between roots allows them to sit in different tokens.
    return korene_in_slovo(pattern, ",".join(tokens))

def spracuj_slova(platne_slova, platne_slova_korenove, tokenizer):
    """Tokenize each word on its own and summarize the result.

    Parameters
    ----------
    platne_slova : collection of str
        The words found in one line.
    platne_slova_korenove : collection of str
        The subset of ``platne_slova`` that has an entry in ``word_root``.
    tokenizer : str
        Which tokenizer to use: ``"sb"``, ``"purebpe"`` or ``"skmt"``.

    Returns
    -------
    tuple of 5 floats, each rounded to 4 decimals
        0. average number of tokens per word,
        1. average number of tokens per root word,
        2. average token length (leading/trailing 'Ġ' stripped),
        3. the same average over the tokens of root words only,
        4. share of root words for which ``najdi_korene`` returns True.
        Any average whose denominator is zero (e.g. an empty line or a line
        without root words) is ``float('nan')`` rather than an exception.

    Raises
    ------
    ValueError
        If ``tokenizer`` is not one of the supported names.
    """
    tokenizers = {
        "sb": tokenizer_SB.tokenize,
        "purebpe": tokenizer_pureBPE.tokenize,
        "skmt": lambda text: tokenizer_skmt.tokenize(
            text, max_length=None, return_tensors=None, return_subword=False
        ),
    }
    try:
        tokenize = tokenizers[tokenizer]
    except KeyError:
        raise ValueError(
            f"unknown tokenizer {tokenizer!r}; expected one of {sorted(tokenizers)}"
        ) from None

    def ratio(numerator, denominator):
        # Undefined averages become NaN instead of a ZeroDivisionError that
        # would abort the whole worker process.
        return round(numerator / denominator, 4) if denominator else float("nan")

    word_tokens = []
    root_word_tokens = []
    broken_roots = 0
    for slovo in platne_slova:
        # Leading space so the word is tokenized as it would appear mid-sentence.
        tokens = dekoduj(tokenize(" " + slovo))
        word_tokens.extend(tokens)
        if slovo in platne_slova_korenove:
            root_word_tokens.extend(tokens)
            if not najdi_korene(tokens, word_root[slovo]):
                broken_roots += 1

    n_root_words = len(platne_slova_korenove)
    koeficient_korenov = (
        round(1 - broken_roots / n_root_words, 4) if n_root_words else float("nan")
    )
    return (
        ratio(len(word_tokens), len(platne_slova)),
        ratio(len(root_word_tokens), n_root_words),
        ratio(sum(len(t.strip("Ġ")) for t in word_tokens), len(word_tokens)),
        ratio(sum(len(t.strip("Ġ")) for t in root_word_tokens), len(root_word_tokens)),
        koeficient_korenov,
    )

def over_line(line):
    """Compute the per-line statistics for all three tokenizers.

    Returns a flat list of 20 values:
    1. the number of distinct words,
    2. the number of distinct words that have a root annotation,
    3. then, for each tokenizer in the order SlovakBERT, pureBPE, SKMT:
       the token count of the whole line followed by the five values
       returned by ``spracuj_slova``.
    """
    # readlines() keeps the line terminator, and byte-level BPE tokenizers
    # typically emit an extra token for '\n'. The last line of a file may lack
    # the terminator, so drop it to keep token counts comparable across lines.
    line = line.rstrip("\r\n")

    # For str patterns \w is Unicode-aware, so it already matches all
    # Slovak/Czech/German letters. A "word" is any maximal run of letters,
    # digits or underscores.
    platne_slova = {slovo for slovo in re.split(r"\W+", line) if slovo}
    platne_slova_korenove = {slovo for slovo in platne_slova if slovo in korenoveSlova}

    len_tokens_SB = len(tokenizer_SB.tokenize(line))
    len_tokens_pureBPE = len(tokenizer_pureBPE.tokenize(line))
    len_tokens_skmt = len(tokenizer_skmt.tokenize(line, max_length=None, return_tensors=None, return_subword=False))

    result_sb = spracuj_slova(platne_slova, platne_slova_korenove, "sb")
    result_purebpe = spracuj_slova(platne_slova, platne_slova_korenove, "purebpe")
    result_skmt = spracuj_slova(platne_slova, platne_slova_korenove, "skmt")

    return [
        len(platne_slova), len(platne_slova_korenove),
        len_tokens_SB, *result_sb,
        len_tokens_pureBPE, *result_purebpe,
        len_tokens_skmt, *result_skmt,
    ]
    
def clean_lines(file_path):
    """Compute ``over_line`` statistics for every line of one UTF-8 text file.

    The file index is the integer between the last '_' of the file name and
    the following '.'; the script sorts its input paths by the same key.

    Returns a list of rows ``[file_index, line_index, *over_line(line)]``.
    """
    file_path = Path(file_path)
    # int, not str: the rows are later sorted by this column, and string
    # sorting would put "10" before "2". Parse .name (not the full path) so
    # underscores in directory names cannot leak into the key.
    file_count = int(file_path.name.split('_')[-1].split('.')[0])
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    results = []
    progress = tqdm(lines, desc=f"Processing {file_path.name}", unit="line")
    for index, line in enumerate(progress):
        results.append([file_count, index] + over_line(line))
    return results

# The guard is required for the "spawn"/"forkserver" start methods: worker
# processes re-import this module, and must not create their own Pool.
if __name__ == "__main__":
    # Input files, ordered by the numeric suffix of their name; Jupyter
    # checkpoint copies are skipped.
    paths = sorted(
        (p for p in Path('../cleaned_500k').glob('**/*.txt') if "ipynb_checkpoints" not in str(p)),
        key=lambda p: int(p.name.split('_')[-1].split('.')[0]),
    )
    # paths = paths[:32]  # uncomment to process only a small subset while debugging

    # Number of worker processes; adjust to the machine.
    num_processes = 100

    results = []
    with Pool(processes=num_processes) as pool, tqdm(total=len(paths)) as pbar:
        # imap_unordered yields each file's rows as soon as that file is done,
        # so the combined rows are re-sorted below.
        for rows in pool.imap_unordered(clean_lines, paths):
            results.extend(rows)
            pbar.update(1)

    column_headers = ['File_Count', 'Line_Number', "Word count", "Word root count", "SB0", "SB1", "SB2", "SB3", "SB4", "SB5", "pureBPE0", "pureBPE1", "pureBPE2", "pureBPE3", "pureBPE4", "pureBPE5", "SKMT0", "SKMT1", "SKMT2", "SKMT3", "SKMT4", "SKMT5"]
    df = pd.DataFrame(results, columns=column_headers)
    # Sort files numerically even if File_Count arrives as a string.
    df['File_Count'] = pd.to_numeric(df['File_Count'])
    df = df.sort_values(by=['File_Count', 'Line_Number'])

    df.to_csv('output_500k.csv', index=False, encoding='utf-8')
    # DataFrame.to_excel no longer accepts `encoding` (deprecated in pandas 1.5,
    # removed in 2.0); passing it raises TypeError.
    df.to_excel('output_500k.xlsx', index=False)
        

            
